{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ic4_occAAiAT"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "ioaprt5q5US7"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "yCl0eTNH5RS3"
      },
      "outputs": [],
      "source": [
        "#@title MIT License\n",
        "#\n",
        "# Copyright (c) 2017 François Chollet\n",
        "#\n",
        "# Permission is hereby granted, free of charge, to any person obtaining a\n",
        "# copy of this software and associated documentation files (the \"Software\"),\n",
        "# to deal in the Software without restriction, including without limitation\n",
        "# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n",
        "# and/or sell copies of the Software, and to permit persons to whom the\n",
        "# Software is furnished to do so, subject to the following conditions:\n",
        "#\n",
        "# The above copyright notice and this permission notice shall be included in\n",
        "# all copies or substantial portions of the Software.\n",
        "#\n",
        "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
        "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
        "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n",
        "# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
        "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n",
        "# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n",
        "# DEALINGS IN THE SOFTWARE."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ItXfxkxvosLH"
      },
      "source": [
        "# Базовая классификация текста"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hKY4XMc9o8iB"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/keras/text_classification\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />Смотрите на TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ru/tutorials/keras/text_classification.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Запустите в Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ru/tutorials/keras/text_classification.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />Изучайте код на GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ru/tutorials/keras/text_classification.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Скачайте ноутбук</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "728d80d02ce9"
      },
      "source": [
        "Note: Вся информация в этом разделе переведена с помощью русскоговорящего Tensorflow сообщества на общественных началах. Поскольку этот перевод не является официальным, мы не гарантируем что он на 100% аккуратен и соответствует [официальной документации на английском языке](https://www.tensorflow.org/?hl=en). Если у вас есть предложение как исправить этот перевод, мы будем очень рады увидеть pull request в [tensorflow/docs](https://github.com/tensorflow/docs) репозиторий GitHub. Если вы хотите помочь сделать документацию по Tensorflow лучше (сделать сам перевод или проверить перевод подготовленный кем-то другим), напишите нам на [docs-ru@tensorflow.org list](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-ru)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Eg62Pmz3o83v"
      },
      "source": [
        "В этом руководстве демонстрируется классификация текста, начиная с простых текстовых файлов, хранящихся на диске. Вы научите бинарный классификатор выполнять анализ тональности набора данных IMDB. В конце руководства есть упражнение, которое вы можете попробовать научить мультиклассовый классификатор предсказывать тег для вопроса программирования на Stack Overflow."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8RZOuS9LWQvv"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import os\n",
        "import re\n",
        "import shutil\n",
        "import string\n",
        "import tensorflow as tf\n",
        "\n",
        "from tensorflow.keras import layers\n",
        "from tensorflow.keras import losses\n",
        "from tensorflow.keras import preprocessing\n",
        "from tensorflow.keras.layers.experimental.preprocessing import TextVectorization"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6-tTFS04dChr"
      },
      "outputs": [],
      "source": [
        "print(tf.__version__)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NBTI1bi8qdFV"
      },
      "source": [
        "## Анализ настроений\n",
        "\n",
        "Это руководство обучает модель анализа настроений классифицировать обзоры фильмов как *положительные* или *отрицательные* на основе текста. Это пример *бинарной* - или двухклассовой классификации, важного и широко применяемого вида машинного обучения.\n",
        "\n",
        "Вы будете использовать [Большой набор данных обзоров фильмов](https://ai.stanford.edu/~amaas/data/sentiment/), который содержит текст 50 000 обзоров фильмов из [База данных фильмов в Интернете](https://www.imdb.com/). Они разделены на 25 000 отзывов для обучения и 25 000 отзывов для тестирования. Наборы для обучения и тестирования *сбалансированы*, что означает, что они содержат равное количество положительных и отрицательных отзывов.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iAsKG535pHep"
      },
      "source": [
        "### Загрузка и изучение набора данных IMDB\n",
        "\n",
        "Давайте загрузим и извлечем набор данных, а затем исследуем структуру каталогов."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k7ZYnuajVlFN"
      },
      "outputs": [],
      "source": [
        "url = \"https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\"\n",
        "\n",
        "dataset = tf.keras.utils.get_file(\"aclImdb_v1.tar.gz\", url,\n",
        "                                    untar=True, cache_dir='.',\n",
        "                                    cache_subdir='')\n",
        "\n",
        "dataset_dir = os.path.join(os.path.dirname(dataset), 'aclImdb')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "355CfOvsV1pl"
      },
      "outputs": [],
      "source": [
        "os.listdir(dataset_dir)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7ASND15oXpF1"
      },
      "outputs": [],
      "source": [
        "train_dir = os.path.join(dataset_dir, 'train')\n",
        "os.listdir(train_dir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ysMNMI1CWDFD"
      },
      "source": [
        "Каталоги `aclImdb/train/pos` и `aclImdb/train/neg` содержат множество текстовых файлов, каждый из которых представляет собой отдельный обзор фильма. Давайте посмотрим на один из них."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "R7g8hFvzWLIZ"
      },
      "outputs": [],
      "source": [
        "sample_file = os.path.join(train_dir, 'pos/1181_9.txt')\n",
        "with open(sample_file) as f:\n",
        "  print(f.read())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mk20TEm6ZRFP"
      },
      "source": [
        "### Загрузка датасета\n",
        "\n",
        "Теперь вы загрузите данные с диска и подготовите их в формате, подходящем для обучения. Для этого вы будете использовать полезную утилиту [text_dataset_from_directory](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/text_dataset_from_directory), которая ожидает следующую структуру каталогов.\n",
        "\n",
        "```\n",
        "main_directory/\n",
        "...class_a/\n",
        "......a_text_1.txt\n",
        "......a_text_2.txt\n",
        "...class_b/\n",
        "......b_text_1.txt\n",
        "......b_text_2.txt\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nQauv38Lnok3"
      },
      "source": [
        "Чтобы подготовить набор данных для бинарной классификации, вам потребуются две папки на диске, `class_a` и` class_b`. Это будут положительные и отрицательные обзоры фильмов, которые можно найти в  `aclImdb/train/pos` и `aclImdb/train/neg`. Поскольку набор данных IMDB содержит дополнительные папки, вы удалите их перед использованием этой утилиты."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VhejsClzaWfl"
      },
      "outputs": [],
      "source": [
        "remove_dir = os.path.join(train_dir, 'unsup')\n",
        "shutil.rmtree(remove_dir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "95kkUdRoaeMw"
      },
      "source": [
        "Затем вы воспользуетесь утилитой `text_dataset_from_directory`, чтобы создать размеченный `tf.data.Dataset`. [tf.data](https://www.tensorflow.org/guide/data) - мощный набор инструментов для работы с данными.\n",
        "\n",
        "При проведении эксперимента с машинным обучением рекомендуется разделить набор данных на три части: [train](https://developers.google.com/machine-learning/glossary#training_set), [validation](https://developers.google.com/machine-learning/glossary#validation_set) и [test](https://developers.google.com/machine-learning/glossary#test-set).\n",
        "\n",
        "Набор данных IMDB уже разделен на обучающий и тестовый, но в нем отсутствует набор для проверки. Давайте создадим набор для проверки, используя разделение обучающих данных 80:20 с помощью аргумент `validation_split`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nOrK-MTYaw3C"
      },
      "outputs": [],
      "source": [
        "batch_size = 32\n",
        "seed = 42\n",
        "\n",
        "raw_train_ds = tf.keras.preprocessing.text_dataset_from_directory(\n",
        "    'aclImdb/train', \n",
        "    batch_size=batch_size, \n",
        "    validation_split=0.2, \n",
        "    subset='training', \n",
        "    seed=seed)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5Y33oxOUpYkh"
      },
      "source": [
        "Как вы видели выше, в папке для обучения 25 000 примеров, 80%(или 20 000) которых вы будете использовать  для обучения. Вы можете обучить модель, передав набор данных непосредственно в `model.fit`. Если вы новичок в `tf.data`, ниже показан пример того, как можно получить и просмотреть несколько примеров из готового датасета."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "51wNaPPApk1K"
      },
      "outputs": [],
      "source": [
        "for text_batch, label_batch in raw_train_ds.take(1):\n",
        "  for i in range(3):\n",
        "    print(\"Review\", text_batch.numpy()[i])\n",
        "    print(\"Label\", label_batch.numpy()[i])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JWq1SUIrp1a-"
      },
      "source": [
        "Обратите внимание, что обзоры содержат необработанный текст(с пунктуацией и случайными HTML-тегами, такими как `<br/>`). В следующем разделе вы узнаете, как с ними справиться.\n",
        "\n",
        "Метки равны 0 или 1. Чтобы увидеть, какие из них соответствуют положительным и отрицательным обзорам, вы можете проверить свойство `class_names` в наборе данных."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MlICTG8spyO2"
      },
      "outputs": [],
      "source": [
        "print(\"Label 0 corresponds to\", raw_train_ds.class_names[0])\n",
        "print(\"Label 1 corresponds to\", raw_train_ds.class_names[1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pbdO39vYqdJr"
      },
      "source": [
        "Далее вы создадите набор данных для проверки и тестирования. Вы будете использовать оставшиеся 5000 отзывов из обучающего набора для проверки."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SzxazN8Hq1pF"
      },
      "source": [
        "Примечание. При использовании аргументов `validation_split` и `subset` убедитесь, что вы указали случайное начальное число(random seed) или передали `shuffle=False`, чтобы проверочные и обучающие пакеты не перекрывались."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JsMwwhOoqjKF"
      },
      "outputs": [],
      "source": [
        "raw_val_ds = tf.keras.preprocessing.text_dataset_from_directory(\n",
        "    'aclImdb/train', \n",
        "    batch_size=batch_size, \n",
        "    validation_split=0.2, \n",
        "    subset='validation', \n",
        "    seed=seed)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rdSr0Nt3q_ns"
      },
      "outputs": [],
      "source": [
        "raw_test_ds = tf.keras.preprocessing.text_dataset_from_directory(\n",
        "    'aclImdb/test', \n",
        "    batch_size=batch_size)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kDA_Lu2PoGyP"
      },
      "source": [
        "Примечание. API предварительной обработки, используемые в следующем разделе, являются экспериментальными в TensorFlow 2.3 и могут быть изменены."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qJmTiO0IYAjm"
      },
      "source": [
        "### Подготовка набор данных для обучения\n",
        "\n",
        "Теперь вы стандартизируете, токенизируете и векторизуете данные, используя полезный слой `preprocessing.TextVectorization`.\n",
        "\n",
        "Стандартизация относится к предварительной обработке текста, обычно она используется для удаления знаков препинания или элементов HTML. Токенизация относится к разделению строк на токены(например, разделение предложения на отдельные слова путем разделения на пробелы). Векторизация относится к преобразованию токенов в числа, чтобы их можно было передать в нейронную сеть. Все эти задачи могут быть выполнены с помощью этого слоя.\n",
        "\n",
        "Как вы видели выше, обзоры содержат различные HTML-теги, такие как `<br />`. Эти теги не будут удалены стандартизатором в слое `TextVectorization`, так как по умолчанию `TextVectorization` преобразует текст в нижний регистр и удаляет знаки пунктуации, но не удаляет HTML. Вы напишете специальную функцию стандартизации для удаления HTML."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZVcHl-SLrH-u"
      },
      "source": [
        "Примечание. Чтобы предотвратить [перекос при обучении/тестировании](https://developers.google.com/machine-learning/guides/rules-of-ml#training-serving_skew)(также известный как перекос при обучении/обслуживании), важно сделать идентичную предварительную обработку данных как для обучения, так и для обучения и тестирования. Чтобы упростить этот процесс, слой `TextVectorization` можно включить непосредственно в вашу модель, как показано далее в этом руководстве."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SDRI_s_tX1Hk"
      },
      "outputs": [],
      "source": [
        "def custom_standardization(input_data):\n",
        "  lowercase = tf.strings.lower(input_data)\n",
        "  stripped_html = tf.strings.regex_replace(lowercase, '<br />', ' ')\n",
        "  return tf.strings.regex_replace(stripped_html,\n",
        "                                  '[%s]' % re.escape(string.punctuation),\n",
        "                                  '')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d2d3Aw8dsUux"
      },
      "source": [
        "Затем вы создадите слой `TextVectorization`. Вы будете использовать этот слой для стандартизации, токенизации и векторизации данных. Укажите значение `int` для аргумента `output_mode` , чтобы создать уникальные целочисленные индексы для каждого токена.\n",
        "\n",
        "Обратите внимание, что вы используете функцию разделения по умолчанию и пользовательскую функцию стандартизации, которую вы определили выше. Вы также определите некоторые константы для модели, например, значение `sequence_length`, которое заставит слой дополнять или обрезать последовательности до значения `sequence_length`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-c76RvSzsMnX"
      },
      "outputs": [],
      "source": [
        "max_features = 10000\n",
        "sequence_length = 250\n",
        "\n",
        "vectorize_layer = TextVectorization(\n",
        "    standardize=custom_standardization,\n",
        "    max_tokens=max_features,\n",
        "    output_mode='int',\n",
        "    output_sequence_length=sequence_length)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vlFOpfF6scT6"
      },
      "source": [
        "Затем вы вызовете `adapt`, чтобы подогнать состояние уровня предварительной обработки к набору данных. Это заставит модель построить индекс строки -> целые числа."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lAhdjK7AtroA"
      },
      "source": [
        "Примечание: при вызове `adapt` важно использовать только тренировочный датасет, использование тестового датасета приведет к утечке информации."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GH4_2ZGJsa_X"
      },
      "outputs": [],
      "source": [
        "# Создайте набор данных только из текста(без меток), затем вызовите adapt\n",
        "train_text = raw_train_ds.map(lambda x, y: x)\n",
        "vectorize_layer.adapt(train_text)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SHQVEFzNt-K_"
      },
      "source": [
        "Давайте создадим функцию, чтобы увидеть результат использования слоя предварительной обработки для некоторых данных."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SCIg_T50wOCU"
      },
      "outputs": [],
      "source": [
        "def vectorize_text(text, label):\n",
        "  text = tf.expand_dims(text, -1)\n",
        "  return vectorize_layer(text), label"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XULcm6B3xQIO"
      },
      "outputs": [],
      "source": [
        "# получить пакет(из 32 отзывов и меток) из набора данных\n",
        "text_batch, label_batch = next(iter(raw_train_ds))\n",
        "first_review, first_label = text_batch[0], label_batch[0]\n",
        "print(\"Review\", first_review)\n",
        "print(\"Label\", raw_train_ds.class_names[first_label])\n",
        "print(\"Vectorized review\", vectorize_text(first_review, first_label))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6u5EX0hxyNZT"
      },
      "source": [
        "Как вы можете видеть выше, каждый токен был заменен целым числом. Вы можете найти токен(строку), которому соответствует каждое целое число, вызвав метод уровня `.get_vocabulary()`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kRq9hTQzhVhW"
      },
      "outputs": [],
      "source": [
        "print(\"1287 ---> \",vectorize_layer.get_vocabulary()[1287])\n",
        "print(\" 313 ---> \",vectorize_layer.get_vocabulary()[313])\n",
        "print('Vocabulary size: {}'.format(len(vectorize_layer.get_vocabulary())))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XD2H6utRydGv"
      },
      "source": [
        "Вы почти готовы обучать свою модель. В качестве последнего шага предварительной обработки вы примените слой `TextVectorization`, ко всем датасетам - для обучения, проверки и тестирования."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2zhmpeViI1iG"
      },
      "outputs": [],
      "source": [
        "train_ds = raw_train_ds.map(vectorize_text)\n",
        "val_ds = raw_val_ds.map(vectorize_text)\n",
        "test_ds = raw_test_ds.map(vectorize_text)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YsVQyPMizjuO"
      },
      "source": [
        "### Настройка датасета для повышения производительности\n",
        "\n",
        "Существует два важных метода, которые вы должны использовать при загрузке данных, чтобы убедиться, что ввод-вывод не становится блокирующим.\n",
        "\n",
        "`.cache()` сохраняет данные в памяти после их загрузки с диска. Этот метода гарантирует, что набор данных не станет узким местом при обучении вашей модели. Если ваш набор данных слишком велик для размещения в памяти, вы также можете использовать этот метод для создания производительного кеша на диске, который более эффективен для чтения, чем многие небольшие файлы.\n",
        "\n",
        "`.prefetch()` перекрывает предварительную обработку данных и выполнение модели во время обучения.\n",
        "\n",
        "Вы можете узнать больше об обоих методах, а также о том, как кэшировать данные на диск, в [руководстве по производительности данных](https://www.tensorflow.org/guide/data_performance)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wMcs_H7izm5m"
      },
      "outputs": [],
      "source": [
        "AUTOTUNE = tf.data.experimental.AUTOTUNE\n",
        "\n",
        "train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)\n",
        "val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)\n",
        "test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LLC02j2g-llC"
      },
      "source": [
        "### Создание модели"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dkQP6in8yUBR"
      },
      "outputs": [],
      "source": [
        "embedding_dim = 16"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xpKOoWgu-llD"
      },
      "outputs": [],
      "source": [
        "model = tf.keras.Sequential([\n",
        "  layers.Embedding(max_features + 1, embedding_dim),\n",
        "  layers.Dropout(0.2),\n",
        "  layers.GlobalAveragePooling1D(),\n",
        "  layers.Dropout(0.2),\n",
        "  layers.Dense(1)])\n",
        "\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6PbKQ6mucuKL"
      },
      "source": [
        "Для построения классификатора мы используем стек слоев:\n",
        "\n",
        "1. Первый слой - это `Embedding`. Этот слой принимает обзоры, закодированные в целочисленном формате, и ищет вектор внедрения для каждого индекса слова. Эти векторы изучаются по мере обучения модели. Векторы добавляют измерение к выходному массиву. В результате получаются следующая размерность: `(batch, sequence, embedding)`. Чтобы узнать больше о технике эмбеддинга, см. [Руководство по эмбеддингу слов](../text/word_embeddings.ipynb).\n",
        "2. Затем слой `GlobalAveragePooling1D`, который возвращает выходной вектор фиксированной длины для каждого примера путем усреднения по размерности. Это позволяет модели обрабатывать ввод переменной длины самым простым из возможных способов.\n",
        "3. Этот выходной вектор фиксированной длины передается по конвейеру через полносвязанный слой(`Dense`) с 16 скрытыми узлами.\n",
        "4. Последний слой - полносвязанный слой(`Dense`) с единственным выходным узлом."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L4EqVWg4-llM"
      },
      "source": [
        "### Функция потерь и оптимайзер\n",
        "\n",
        "Для обучения модели нужны функция потерь и оптимайзер. Поскольку это проблема бинарной классификации и модель выводит вероятность(единичный слой с активацией сигмоид), вы будете использовать функцию потерь `losses.BinaryCrossentropy`.\n",
        "\n",
        "Теперь настройте модель для использования оптимайзера и функции потерь:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Mr0GP-cQ-llN"
      },
      "outputs": [],
      "source": [
        "model.compile(loss=losses.BinaryCrossentropy(from_logits=True),\n",
        "              optimizer='adam',\n",
        "              metrics=tf.metrics.BinaryAccuracy(threshold=0.0))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "35jv_fzP-llU"
      },
      "source": [
        "### Обучение модели\n",
        "\n",
        "Вы обучите модель, передав объект `dataset` методу `fit`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tXSGrjWZ-llW"
      },
      "outputs": [],
      "source": [
        "epochs = 10\n",
        "history = model.fit(\n",
        "    train_ds,\n",
        "    validation_data=val_ds,\n",
        "    epochs=epochs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9EEGuDVuzb5r"
      },
      "source": [
        "### Оценка модели\n",
        "\n",
        "Посмотрим, как модель работает. Будут возвращены два значения. Величина потерь(число, которое представляет нашу ошибку, меньшие значения лучше) и точность.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zOMKywn4zReN"
      },
      "outputs": [],
      "source": [
        "loss, accuracy = model.evaluate(test_ds)\n",
        "\n",
        "print(\"Loss: \", loss)\n",
        "print(\"Accuracy: \", accuracy)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z1iEXVTR0Z2t"
      },
      "source": [
        "Этот довольно наивный подход обеспечивает точность около 86%."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ldbQqCw2Xc1W"
      },
      "source": [
        "### Создание графика точности и потерь\n",
        "\n",
        "`model.fit()` возвращает объект `History`, который содержит словарь со всем, что произошло во время обучения:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-YcvZsdvWfDf"
      },
      "outputs": [],
      "source": [
        "history_dict = history.history\n",
        "history_dict.keys()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1_CH32qJXruI"
      },
      "source": [
        "Есть четыре записи: по одной для каждой отслеживаемой метрики во время обучения и проверки. Вы можете использовать их для построения графика потерь при обучении и проверке для сравнения, а также для определения точности обучения и проверки:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2SEMeQ5YXs8z"
      },
      "outputs": [],
      "source": [
        "acc = history_dict['binary_accuracy']\n",
        "val_acc = history_dict['val_binary_accuracy']\n",
        "loss = history_dict['loss']\n",
        "val_loss = history_dict['val_loss']\n",
        "\n",
        "epochs = range(1, len(acc) + 1)\n",
        "\n",
        "# \"bo\" is for \"blue dot\"\n",
        "plt.plot(epochs, loss, 'bo', label='Training loss')\n",
        "# b is for \"solid blue line\"\n",
        "plt.plot(epochs, val_loss, 'b', label='Validation loss')\n",
        "plt.title('Training and validation loss')\n",
        "plt.xlabel('Epochs')\n",
        "plt.ylabel('Loss')\n",
        "plt.legend()\n",
        "\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Z3PJemLPXwz_"
      },
      "outputs": [],
      "source": [
        "plt.plot(epochs, acc, 'bo', label='Training acc')\n",
        "plt.plot(epochs, val_acc, 'b', label='Validation acc')\n",
        "plt.title('Training and validation accuracy')\n",
        "plt.xlabel('Epochs')\n",
        "plt.ylabel('Accuracy')\n",
        "plt.legend(loc='lower right')\n",
        "\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hFFyCuJoXy7r"
      },
      "source": [
        "Точки на графике представляют потери и точность при обучении, а сплошные линии - потери и точность при оценке.\n",
        "\n",
        "Обратите внимание, что потери при обучении *уменьшаются* с каждой эпохой, а точность обучения *увеличивается* с каждой эпохой. Это ожидаемое поведение при использовании оптимизации градиентного спуска - он должен минимизировать ошибку на каждой итерации.\n",
        "\n",
        "Это не относится к потерям и точности при оценке - они достигают пика раньше точности обучения. Это пример переобучения: модель лучше работает с данными обучения, чем с данными, которых она никогда раньше не видела. После определенной точки модель чрезмерно оптимизирует и изучает признаки, *специфичные* для обучающих данных, которые не *характерны* для тестовых данных.\n",
        "\n",
        "В этом конкретном случае вы можете избежать переобучения, просто остановив тренировку модели, когда точность проверки больше не увеличивается. Один из способов сделать это - использовать обратный вызов [EarlyStopping](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/EarlyStopping?version=nightly)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-to23J3Vy5d3"
      },
      "source": [
        "## Экспорт модели\n",
        "\n",
        "В приведенном выше коде вы применили слой `TextVectorization` к набору данных перед отправкой текста в модель. Если вы хотите, чтобы ваша модель могла работать с необработанными строками(например, чтобы упростить ее развертывание), вы можете включить слой `TextVectorization` в вашу модель. Для этого вы можете создать новую модель, используя только что обученные веса."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FWXsMvryuZuq"
      },
      "outputs": [],
      "source": [
        "export_model = tf.keras.Sequential([\n",
        "  vectorize_layer,\n",
        "  model,\n",
        "  layers.Activation('sigmoid')\n",
        "])\n",
        "\n",
        "export_model.compile(\n",
        "    loss=losses.BinaryCrossentropy(from_logits=False), optimizer=\"adam\", metrics=['accuracy']\n",
        ")\n",
        "\n",
        "# Протестируйте ее с помощью вызова raw_test_ds, который возвращает необработанные строки\n",
        "loss, accuracy = export_model.evaluate(raw_test_ds)\n",
        "print(accuracy)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TwQgoN88LoEF"
      },
      "source": [
        "### Вывод для новых данных\n",
        "\n",
        "Чтобы получить прогнозы для новых данных, вы можете просто вызвать `model.predict()`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QW355HH5L49K"
      },
      "outputs": [],
      "source": [
        "examples = [\n",
        "  \"The movie was great!\",\n",
        "  \"The movie was okay.\",\n",
        "  \"The movie was terrible...\"\n",
        "]\n",
        "\n",
        "export_model.predict(examples)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MaxlpFWpzR6c"
      },
      "source": [
        "Включение логики предварительной обработки текста в модель позволяет экспортировать модель в продакшен, что упрощает развертывание и снижает вероятность [перекоса обучения/тестирования](https://developers.google.com/machine-learning/guides/rules-of-ml#training-serve_skew).\n",
        "\n",
        "При выборе места для применения слоя `TextVectorization` следует учитывать разницу в производительности. Использование его вне модели позволяет выполнять асинхронную обработку на CPU и буферизацию данных при обучении на GPU. Таким образом, если вы тренируете свою модель на графическом процессоре, вы, вероятно, захотите использовать этот вариант, чтобы получить максимальную производительность при разработке модели, а затем переключитесь на включение слоя `TextVectorization` в вашу модель, когда вы будете готовы к развертыванию.\n",
        "\n",
        "Посмотрите этот [учебник](https://www.tensorflow.org/tutorials/keras/save_and_load), чтобы узнать больше о сохранении моделей."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eSSuci_6nCEG"
      },
      "source": [
        "## Упражнение: мультиклассовая классификация по вопросам со Stack Overflow\n",
        "\n",
        "В этом руководстве показано, как обучить бинарный классификатор на наборе данных IMDB. В качестве упражнения вы можете изменить этот ноутбук, чтобы обучить мультиклассовый классификатор предсказывать тег вопроса по программированию на [Stack Overflow](http://stackoverflow.com/).\n",
        "\n",
        "Мы подготовили для вас [набор данных](http://storage.googleapis.com/download.tensorflow.org/data/stack_overflow_16k.tar.gz), содержащий несколько тысяч вопросов по программированию(например, «Как можно отсортировать словарь по значению в Python?»), размещенных на Stack Overflow. Каждый из них помечен ровно одним тегом(Python, CSharp, JavaScript или Java). Ваша задача - взять вопрос в качестве входных данных и предсказать соответствующий тег, в данном случае Python.\n",
        "\n",
        "Набор данных, с которым вы будете работать, содержит несколько тысяч вопросов, извлеченных из гораздо большего общедоступного набора данных Stack Overflow на [BigQuery](https://console.cloud.google.com/marketplace/details/stack-exchange/stack-overflow), который содержит более 17 миллионов сообщений.\n",
        "\n",
        "После загрузки датасета вы обнаружите, что он имеет структуру каталогов, аналогичную датасету IMDB, с которым вы работали ранее:\n",
        "\n",
        "```\n",
        "train/\n",
        "...python/\n",
        "......0.txt\n",
        "......1.txt\n",
        "...javascript/\n",
        "......0.txt\n",
        "......1.txt\n",
        "...csharp/\n",
        "......0.txt\n",
        "......1.txt\n",
        "...java/\n",
        "......0.txt\n",
        "......1.txt\n",
        "```\n",
        "\n",
        "Примечание: чтобы усложнить задачу классификации, мы заменили любые слова Python, CSharp, JavaScript или Java в вопросах программирования словом *blank*(так как многие вопросы содержат название языка).\n",
        "\n",
        "Чтобы выполнить это упражнение, вам следует изменить этот ноутбук для работы с набором данных Stack Overflow, внеся следующие изменения:\n",
        "\n",
        "1. В верхней части записной книжки обновите код, загружающий набор данных IMDB, на код для загрузки [набора данных Stack Overflow](http://storage.googleapis.com/download.tensorflow.org/data/stack_overflow_16k.tar.gz). Поскольку набор данных Stack Overflow имеет аналогичную структуру каталогов, вам не нужно будет вносить много изменений.\n",
        "\n",
        "1. Измените последний слой вашей модели так, чтобы он читался как `Dense(4)`, поскольку теперь существует четыре выходных класса.\n",
        "\n",
        "1. При компиляции модели измените функцию потерь на [SparseCategoricalCrossentropy](https://www.tensorflow.org/api_docs/python/tf/keras/losses/SparseCategoricalCrossentropy?version=nightly). Это правильная функция потерь для использования в задаче классификации нескольких классов, когда метки для каждого класса являются целыми числами (в нашем случае они могут быть *0*, *1*, *2* или *3*).\n",
        "\n",
        "1. После внесения этих изменений вы сможете обучать мультиклассовый классификатор.\n",
        "\n",
        "Если же у вас возникли проблемы, вы можете найти готовое решение [здесь](https://github.com/tensorflow/examples/blob/master/community/en/text_classification_solution.ipynb)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F0T5SIwSm7uc"
      },
      "source": [
        "## Узнать больше\n",
        "\n",
        "В этом руководстве была описана классификация текста с нуля. Чтобы узнать больше о рабочем процессе классификации текста в целом, мы рекомендуем прочитать [это руководство](https://developers.google.com/machine-learning/guides/text-classification/) от разработчиков Google."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "text_classification.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
